import os
import json
import time
import datetime

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric

import nagphormer_blocks
from ssl_tasks import clustering, pairsim, dgi, partition, pairdis

# %%
# Layer constructors selectable through a config's ``'layer_type'`` key. ``'lin'`` is a
# plain ``torch.nn.Linear``; the other entries are PyTorch Geometric graph layers that
# this file calls with an ``(x, edge_index)`` signature.
LAYER_MAP = dict(
    lin=torch.nn.Linear,
    gcn=torch_geometric.nn.GCNConv,
    sage=torch_geometric.nn.SAGEConv,
    gat=torch_geometric.nn.GATConv,
    gtx=torch_geometric.nn.TransformerConv,
)

# Activations selectable through a config's ``'act'`` key.
# NOTE: these are module *instances* created once at import time. Placing the same entry
# into several layers shares one object (for ``'prelu'`` that means one shared learnable
# slope), so copy an entry before re-using it in more than one layer.
ACT_MAP = dict(
    relu=torch.nn.ReLU(inplace=True),
    prelu=torch.nn.PReLU(),
)

# %%


class GNNEncoder(torch.nn.Module):
    """Graph encoder made of per-input-width "stems" followed by one shared backbone.

    Each stem maps node features of one specific width (``stem['num_node_f']``) into the
    backbone's input space; ``forward`` picks the stem whose input width equals
    ``x.shape[-1]`` and then runs the backbone.

    Args:
        stems: list of stem config dicts with keys ``'num_node_f'`` (input width),
            ``'num_layer_features'`` (list of layer output widths), ``'layer_type'``
            (a ``LAYER_MAP`` key) and ``'act'`` (an ``ACT_MAP`` key).
        backbone: backbone config dict with keys ``'num_in_features'``,
            ``'num_layer_features'``, ``'layer_type'``, ``'act'`` and, for ``'gat'`` /
            ``'gtx'``, ``'num_heads'`` (one entry per layer).
        device: stored as ``self.device_``; the module itself is not moved here.
    """

    def __init__(self, stems, backbone, device=torch.device('cuda')):
        super().__init__()
        self.configs = {
            'stems': stems,
            'backbone': backbone,
        }
        self.device_ = device

        # Construct stems (one per supported input width)
        self.stem_in_features = [stem['num_node_f'] for stem in stems]
        self.stems = torch.nn.ModuleList([self._build_stem(stem) for stem in stems])

        # Construct backbone
        self.backbone = self._build_backbone(backbone)

    @staticmethod
    def _make_activation(name):
        """Return a fresh activation module for ``name`` (an ``ACT_MAP`` key).

        ``ACT_MAP`` stores module instances; appending them directly would make every layer
        (of every model built in this process) share one object, e.g. a single PReLU slope
        updated by all of them. Copy the entry and re-initialise any parameters instead.
        """
        act = copy.deepcopy(ACT_MAP[name])
        if hasattr(act, 'reset_parameters'):
            act.reset_parameters()
        return act

    def _build_stem(self, stem):
        """Build one stem: ``stem['layer_type']`` layers, each followed by an activation."""
        sizes = [stem['num_node_f'], *stem['num_layer_features']]
        is_lin = stem['layer_type'] == 'lin'
        layers = []
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            layer_cls = LAYER_MAP[stem['layer_type']]
            if is_lin:
                layers.append(layer_cls(in_features=in_size, out_features=out_size))
            else:
                layers.append((layer_cls(in_channels=in_size, out_channels=out_size), 'x, edge_index -> x'))
            layers.append(self._make_activation(stem['act']))
        if is_lin:
            return torch.nn.Sequential(*layers)
        return torch_geometric.nn.Sequential('x, edge_index', layers)

    def _build_backbone(self, backbone):
        """Build the backbone: graph layers (plus a head projection for gat/gtx) and activations."""
        sizes = [backbone['num_in_features'], *backbone['num_layer_features']]
        layer_type = backbone['layer_type']
        layers = []
        for i, (in_size, out_size) in enumerate(zip(sizes[:-1], sizes[1:])):
            layer_cls = LAYER_MAP[layer_type]
            if layer_type in ['gtx', 'gat']:
                conv = layer_cls(in_channels=in_size, out_channels=out_size, heads=backbone['num_heads'][i])
                layers.append((conv, 'x, edge_index -> x'))
                # The projection's input width assumes the heads are concatenated
                # (out_channels * heads); map back to out_size.
                layers.append(torch.nn.Linear(in_features=conv.out_channels * conv.heads,
                                              out_features=out_size))
            else:
                layers.append((layer_cls(in_channels=in_size, out_channels=out_size), 'x, edge_index -> x'))
            layers.append(self._make_activation(backbone['act']))
            # layers.append(norm)
            # layers.append(torch_geometric.nn.Dropout(p=0.5))
        return torch_geometric.nn.Sequential('x, edge_index', layers)

    def forward(self, x, edge_index):
        """Embed nodes with the stem matching ``x.shape[-1]`` followed by the backbone.

        Raises:
            ValueError: if no stem was configured for the input feature width.
        """
        stem_id = self.stem_in_features.index(x.shape[-1])
        if self.configs['stems'][stem_id]['layer_type'] != 'lin':
            x = self.stems[stem_id](x, edge_index)
        else:
            x = self.stems[stem_id](x)
        x = self.backbone(x, edge_index)
        return x


class GNNClassifier(GNNEncoder):
    """Node classifier built from one stem of a ``GNNEncoder``, its backbone and a linear head.

    Args:
        stems, backbone, device: forwarded to ``GNNEncoder``.
        num_features: input feature width; selects the stem whose ``num_node_f`` matches.
        num_classes: output width of the linear ``predictor``.
        state_dict: optional encoder weights, loaded while ``self.stems`` is still the full
            ``ModuleList`` (i.e. before it is reduced to the selected stem); sets
            ``self.pretrained``.
    """
    save_freq = 5

    def __init__(self, stems, backbone, num_features, num_classes, device, state_dict=None):
        super().__init__(stems, backbone, device)
        self.training_type = None
        self.lr_scheduler = None
        self.optimizer = None
        self.pretrained = False
        if state_dict is not None:
            self.load_state_dict(state_dict)
            self.pretrained = True
        # Keep only the stem matching the input width and remember its index, so that
        # forward dispatches on *this* stem's layer type.
        self.stem_id = [stem['num_node_f'] for stem in self.configs['stems']].index(num_features)
        self.stems = self.stems[self.stem_id]
        if self.configs['backbone']['layer_type'] in ['gat', 'gtx']:
            # backbone[-2] is the Linear projecting the multi-head output
            embedding_size = self.backbone[-2].out_features
        else:
            # backbone[-2] is the last graph convolution
            embedding_size = self.backbone[-2].out_channels
        self.predictor = torch.nn.Linear(in_features=embedding_size, out_features=num_classes)

    def forward(self, x, edge_index):
        """Return class logits for every node."""
        if self.configs['stems'][self.stem_id]['layer_type'] != 'lin':
            x = self.stems(x, edge_index)
        else:
            x = self.stems(x)
        x = self.backbone(x, edge_index)
        x = self.predictor(x)
        return x

    def train_model(self, dataset, loss, optimizer, num_epochs, freeze_encoder, save_path=None, verbose=50,
                    lr_adapt=True):
        """Full-batch supervised training, then evaluation on the test mask.

        Args:
            dataset: object whose ``data`` has ``x``, ``edge_index``, ``y`` and boolean
                ``train_mask`` / ``val_mask`` / ``test_mask``.
            loss: callable ``loss(preds, targets)`` returning a scalar tensor.
            optimizer: optimizer over this model's parameters.
            num_epochs: number of full-batch epochs.
            freeze_encoder: if true, stem and backbone parameters stop requiring grad.
            save_path: if given, the model is saved whenever the validation loss improves
                (checked every ``save_freq`` epochs and on the last epoch).
            verbose: print progress every ``verbose`` epochs.
            lr_adapt: attach a ``ReduceLROnPlateau`` scheduler driven by the training loss.

        Returns:
            ``(test_loss, test_accuracy)`` from ``evaluate_model``.
        """
        self.optimizer = optimizer
        if freeze_encoder:
            self.freeze_encoder()
        self.train()
        self.to(self.device_)
        dataset.data.to(self.device_)
        if lr_adapt:
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)
        best_loss = float('inf')  # any finite validation loss counts as an improvement
        print('starting training of classifier')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()

            # forward
            train_mask = dataset.data.train_mask
            preds = self.forward(x=dataset.data.x, edge_index=dataset.data.edge_index)
            train_loss = loss(preds[train_mask], dataset.data.y[train_mask])

            # backward and update
            train_loss.backward()
            self.optimizer.step()

            # save
            if save_path is not None and ((epoch + 1) % self.save_freq == 0 or (epoch + 1) == num_epochs):
                # re-uses this epoch's forward pass (computed before the optimizer step)
                val_mask = dataset.data.val_mask
                with torch.no_grad():
                    val_loss = loss(preds[val_mask], dataset.data.y[val_mask]).item()
                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_model(epoch, save_path)

            # print
            if (epoch + 1) % verbose == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; classification_loss: ' + str(round(train_loss.item(), 4))
                )

            # lr schedule
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(train_loss)

        # print
        print('training completed. elapsed time: ', str(time.time() - start_time))

        eval_loss, eval_acc = self.evaluate_model(dataset, loss)
        print('evaluation loss: ', str(round(eval_loss, 4)))
        print('evaluation accuracy: ', str(round(eval_acc, 4)))

        return eval_loss, eval_acc

    def freeze_encoder(self):
        """Stop gradient updates for the stem and backbone (the predictor stays trainable)."""
        for p in self.stems.parameters():
            p.requires_grad = False
        for p in self.backbone.parameters():
            p.requires_grad = False

    def save_model(self, epoch, save_path):
        """Write model/optimizer state, configs and epoch to ``save_path`` with ``torch.save``."""
        save_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer': {
                'type': 'adam',
                'state_dict': self.optimizer.state_dict(),
            },
            'configs': self.configs,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def evaluate_model(self, dataset, loss):
        """Return ``(loss, accuracy)`` on ``dataset.data.test_mask``; leaves the model in eval mode."""
        self.eval()
        mask = dataset.data.test_mask
        with torch.no_grad():
            preds = self.forward(x=dataset.data.x, edge_index=dataset.data.edge_index)
            eval_loss = loss(preds[mask], dataset.data.y[mask])
        y_hat = preds.argmax(dim=1)
        test_correct = y_hat[mask] == dataset.data.y[mask]
        test_acc = int(test_correct.sum()) / int(mask.sum())
        return eval_loss.item(), test_acc


class SSLGNN(torch.nn.Module):
    """Multi-task self-supervised pre-training of a ``GNNEncoder`` over several datasets.

    Each requested entry of ``ssl_tasks`` (``'clustering'``, ``'pairsim'``, ``'dgi'``,
    ``'partition'``, ``'pairdis'``) instantiates one task object whose
    ``get_loss(embeddings, dataset_name)`` contributes to the pre-training loss. Stems,
    backbone and task heads each get their own Adam optimizer plus a plateau and a
    linear LR scheduler.
    """
    base_lr = 1e-3
    save_freq = 20
    # Loss weights applied when ``pretrain(weighted=True)``; keyed by each task object's ``name``.
    task_weights = {
        'graph partitioning': 0.9,
        'deep graph infomax': 0.65,
        'pair-wise attribute similarity': 0.65,
        'pair-wise distance': 0.1,
        'clustering': 0.1,
    }

    def __init__(self, encoder, data, ssl_tasks):
        super().__init__()
        # model
        self.encoder = encoder
        self.encoder.to(self.encoder.device_)

        # data (``data.datasets`` is iterated during pre-training)
        self.data = data

        # ssl tasks
        self.ssl_tasks = ssl_tasks
        self.ssl_objects = self.instantiate_ssl()

        # optimization
        self.backbone_optimizer = torch.optim.Adam(self.encoder.backbone.parameters(), lr=self.base_lr)
        self.stem_optimizers = [torch.optim.Adam(stem.parameters(), lr=self.base_lr) for stem in self.encoder.stems]
        self.head_optimizers = [torch.optim.Adam(ssl_object.predictor.parameters(), lr=self.base_lr)
                                for ssl_object in self.ssl_objects if ssl_object.name != 'graph partitioning']
        if 'partition' in self.ssl_tasks:
            # the partition task exposes an iterable of predictors, one optimizer each
            partition_task = [ssl_object for ssl_object in self.ssl_objects if ssl_object.name == 'graph partitioning']
            partition_task = partition_task[0]
            partition_optimizers = [torch.optim.Adam(predictor.parameters(), lr=self.base_lr)
                                    for predictor in partition_task.predictor]
            self.head_optimizers = [*self.head_optimizers, *partition_optimizers]
        self.plateau_lr_scheduler = {
            'stems': [
                torch.optim.lr_scheduler.ReduceLROnPlateau(stem_optimizer, patience=50, factor=1 / 3., verbose=True)
                for stem_optimizer in self.stem_optimizers
            ],
            'backbone': torch.optim.lr_scheduler.ReduceLROnPlateau(self.backbone_optimizer, patience=50, factor=1 / 3.,
                                                                   verbose=True),
            'heads': [
                torch.optim.lr_scheduler.ReduceLROnPlateau(head_optimizer, patience=50, factor=1 / 3., verbose=True)
                for head_optimizer in self.head_optimizers
            ],
        }
        self.lin_lr_scheduler = {
            'stems': [
                torch.optim.lr_scheduler.LinearLR(stem_optimizer) for stem_optimizer in self.stem_optimizers
            ],
            'backbone': torch.optim.lr_scheduler.LinearLR(self.backbone_optimizer),
            'heads': [
                torch.optim.lr_scheduler.LinearLR(head_optimizer) for head_optimizer in self.head_optimizers
            ],
        }
        self.lr_scheduler = [self.plateau_lr_scheduler, self.lin_lr_scheduler]

    def instantiate_ssl(self):
        """Create one SSL task object per requested task, sized to the encoder's output width."""
        ssl_objects = []
        if self.encoder.configs['backbone']['layer_type'] in ['gat', 'gtx']:
            # backbone[-2] is the Linear projecting the multi-head output
            embedding_size = self.encoder.backbone[-2].out_features
        else:
            # backbone[-2] is the last graph convolution
            embedding_size = self.encoder.backbone[-2].out_channels
        if 'clustering' in self.ssl_tasks:
            ssl_objects.append(
                clustering.ClusteringTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'pairsim' in self.ssl_tasks:
            ssl_objects.append(
                pairsim.PairwiseAttrSimTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'dgi' in self.ssl_tasks:
            ssl_objects.append(
                dgi.DGITask(data=self.data, encoder=self.encoder, embedding_size=embedding_size,
                            device=self.encoder.device_)
            )
        if 'partition' in self.ssl_tasks:
            ssl_objects.append(
                partition.PartitionTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'pairdis' in self.ssl_tasks:
            ssl_objects.append(
                pairdis.PairwiseDistanceTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        return ssl_objects

    def pretrain(self, num_epochs=1000, weighted=False, lr_adapt=True, verbose_freq=50, save_path=None):
        """Run multi-task SSL pre-training.

        Per epoch, every dataset in ``self.data.datasets`` is embedded with the encoder and
        the (optionally ``task_weights``-weighted) sum of all task losses is back-propagated.
        Stems are stepped after each dataset. Backbone and heads accumulate gradients over
        all datasets and are stepped once per epoch. The epoch loss used for checkpointing,
        printing and LR scheduling is the sum over all datasets.

        Args:
            num_epochs: number of epochs.
            weighted: scale each task loss by ``task_weights[task.name]``.
            lr_adapt: step the linear and plateau LR schedulers each epoch.
            verbose_freq: print progress every ``verbose_freq`` epochs.
            save_path: optional dict with ``'folder'`` (``config.txt`` is written there) and
                ``'file'`` (checkpoint path passed to ``save_model``).
        """
        if lr_adapt is False:
            self.lr_scheduler = None

        if save_path is not None:
            with open(os.path.join(save_path['folder'], 'config.txt'), 'w') as f:
                f.write(json.dumps(self.encoder.configs))
                f.write("\n" + json.dumps({'weighted': weighted}))

        best_loss = float('inf')
        self.train()
        print('starting pre-training')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad: backbone and heads accumulate over all datasets of this epoch
            self.backbone_optimizer.zero_grad()
            for head_optimizer in self.head_optimizers:
                head_optimizer.zero_grad()

            # forward + backward + update stems
            epoch_loss = 0.0
            for dataset in self.data.datasets:
                for stem_optimizer in self.stem_optimizers:
                    stem_optimizer.zero_grad()
                # embeddings
                dataset.data.to(self.encoder.device_)
                x = self.encoder(x=dataset.data.x, edge_index=dataset.data.edge_index)

                # ssl
                ssl_loss = 0
                for ssl_object in self.ssl_objects:
                    task_loss = ssl_object.get_loss(x, dataset.name)
                    if weighted:
                        task_loss = self.task_weights[ssl_object.name] * task_loss
                    ssl_loss = ssl_loss + task_loss

                dataset.data.to(torch.device('cpu'))  # remove data from gpu

                ssl_loss.backward()

                # update parameters: stem
                for stem_optimizer in self.stem_optimizers:
                    stem_optimizer.step()

                epoch_loss += ssl_loss.item()

            # update parameters: heads + backbone
            for head_optimizer in self.head_optimizers:
                head_optimizer.step()
            self.backbone_optimizer.step()

            # save and overwrite model
            if epoch_loss < best_loss and (epoch + 1) % self.save_freq == 0:
                best_loss = epoch_loss
                if save_path is not None:
                    self.save_model(epoch, save_path['file'])
                else:
                    print('no save_path provided. skipping saving')

            # print
            if (epoch + 1) % verbose_freq == 0 or epoch == 0 or epoch == num_epochs - 1:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; ssl_loss: ' + str(round(epoch_loss, 4))
                )

            # lr schedule (honours lr_adapt; previously the schedulers always ran)
            if lr_adapt:
                self.adjust_lr(epoch_loss)
        # print
        print('pre-training completed. elapsed time: ', str(time.time() - start_time))

    def save_model(self, epoch, save_path):
        """Write data, encoder/head/optimizer states, configs and epoch to ``save_path``.

        With the partition task, ``'heads'`` is ``[other_head_dicts, partition_dicts]``.
        """
        head_state_dict = [ssl_object.predictor.state_dict() for ssl_object in self.ssl_objects
                           if ssl_object.name != 'graph partitioning']
        if 'partition' in self.ssl_tasks:
            partition_task = [ssl_object for ssl_object in self.ssl_objects if ssl_object.name == 'graph partitioning']
            partition_task = partition_task[0]
            partition_dicts = [predictor.state_dict() for predictor in partition_task.predictor]
            head_state_dict = [head_state_dict, partition_dicts]
        save_dict = {
            'data': self.data,
            'model_state_dict': {
                'heads': head_state_dict,
                'encoder': self.encoder.state_dict(),
            },
            'optimizer': {
                'type': 'adam',
                'base_lr': self.base_lr,
                'state_dicts': {
                    'backbone': self.backbone_optimizer.state_dict(),
                    'stems': [stem_optimizer.state_dict() for stem_optimizer in self.stem_optimizers],
                    'heads': [head_optimizer.state_dict() for head_optimizer in self.head_optimizers],
                },
            },
            'configs': self.encoder.configs,
            'ssl_tasks': self.ssl_tasks,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def adjust_lr(self, ssl_loss):
        """Step every linear scheduler, then every plateau scheduler with ``ssl_loss``."""
        # stems
        for stem_scheduler in self.lin_lr_scheduler['stems']:
            stem_scheduler.step()
        for stem_scheduler in self.plateau_lr_scheduler['stems']:
            stem_scheduler.step(ssl_loss)

        # backbone
        self.lin_lr_scheduler['backbone'].step()
        self.plateau_lr_scheduler['backbone'].step(ssl_loss)

        # heads
        for head_scheduler in self.lin_lr_scheduler['heads']:
            head_scheduler.step()
        for head_scheduler in self.plateau_lr_scheduler['heads']:
            head_scheduler.step(ssl_loss)

    # NOTE(review): the disabled helper below reads self.stem_state_dict /
    # self.backbone_state_dict, which this class never sets.
    # def compare_state_dicts(self):
    #     s1 = [copy.deepcopy(stem.state_dict()) for stem in self.encoder.stems]
    #     b1 = copy.deepcopy(self.encoder.backbone.state_dict())
    #     s_keys = [s_.keys() for s_ in s1]
    #     b_keys = b1.keys()
    #     s_sum = [0] * len(s1)
    #     for i in range(len(s1)):
    #         s_i, s1_i = self.stem_state_dict[i], s1[i]
    #         for key in s_keys[i]:
    #             s_sum[i] = s_sum[i] + torch.allclose(s_i[key], s1_i[key])
    #         s_sum[i] = s_sum[i] / len(s_keys[i])
    #         print('stem ' + str(i) + ' changed' if s_sum[i] != 1 else 'stem ' + str(i) + ' not changed')
    #     b_sum = 0
    #     c = 0
    #     for key in b_keys:
    #         if 'running' not in key and 'num_batches' not in key:
    #             b_sum = b_sum + torch.allclose(self.backbone_state_dict[key], b1[key])
    #             c = c + 1
    #     b_sum = b_sum / c
    #     print('backbone changed' if b_sum != 1 else 'backbone not changed')

# %%


class TransformerEncoder(torch.nn.Module):
    """Transformer encoder over per-node token sequences with an attention-based readout.

    Input ``x`` is indexed as ``x[:, token, feature]``: token 0 is treated as the node
    itself and tokens ``1 .. num_hops`` as its neighbour tokens. One linear stem per
    supported feature width projects tokens to ``backbone['num_layer_features']``
    dimensions; ``backbone['num_layers']`` ``nagphormer_blocks.EncoderLayer`` blocks and a
    final LayerNorm follow. The readout attends from the node token over the neighbour
    tokens and projects to ``hidden_dim / 2`` features.

    NOTE(review): ``pos_embed_dim`` and ``scaling`` are stored/registered but not used by
    ``forward``.
    """

    def __init__(self, stems, backbone, num_hops=2, pos_embed_dim=15, dropout_rate=0., attention_dropout_rate=0.1,
                 device=torch.device('cuda')):
        super().__init__()
        self.configs = dict(stems=stems, backbone=backbone)
        self.device_ = device
        self.sequence_length = num_hops + 1  # node token + one token per hop
        self.pos_embed_dim = pos_embed_dim
        hidden = backbone['num_layer_features']
        self.hidden_dim = hidden
        self.ffn_dim = 2 * hidden
        self.dropout_rate = dropout_rate
        self.attention_dropout_rate = attention_dropout_rate
        num_layers = backbone['num_layers']

        # One linear stem per supported input width
        self.stem_in_features = [stem['num_node_f'] for stem in stems]
        self.stems = torch.nn.ModuleList([
            torch.nn.Linear(in_features=stem['num_node_f'], out_features=stem['num_layer_features'])
            for stem in stems
        ])

        # Stack of transformer blocks: (attention dim, FFN dim, dropout, attention dropout, heads)
        self.backbone = torch.nn.ModuleList([
            nagphormer_blocks.EncoderLayer(hidden, self.ffn_dim, self.dropout_rate,
                                           self.attention_dropout_rate, backbone['num_heads'])
            for _ in range(num_layers)
        ])
        self.final_ln = torch.nn.LayerNorm(hidden)

        # Readout: projection plus scorer over concatenated (node, neighbour) token pairs
        self.out_proj = nn.Linear(hidden, int(hidden / 2))
        self.attn_layer = nn.Linear(2 * hidden, 1)
        self.scaling = nn.Parameter(torch.ones(1) * 0.5)

        self.apply(lambda module: nagphormer_blocks.init_params(module, num_layers=num_layers))

    def _attention_readout(self, h):
        """Pool token sequence ``h`` into one vector per node (before ``out_proj``)."""
        num_neighbors = self.sequence_length - 1
        node_token, neighbor_tokens = torch.split(h, [1, num_neighbors], dim=1)
        query = h[:, 0, :].unsqueeze(1).repeat(1, num_neighbors, 1)
        scores = self.attn_layer(torch.cat((query, neighbor_tokens), dim=2))
        weights = F.softmax(scores, dim=1)
        pooled = (neighbor_tokens * weights).sum(dim=1, keepdim=True)
        return (node_token + pooled).squeeze(1)

    def forward(self, x):
        """Embed token sequences using the stem matching ``x.shape[-1]``.

        Raises:
            ValueError: if no stem was configured for the input feature width.
        """
        stem = self.stems[self.stem_in_features.index(x.shape[-1])]
        h = stem(x)
        for block in self.backbone:
            h = block(h)
        h = self.final_ln(h)
        return self.out_proj(self._attention_readout(h))


class TransformerNodeClassifier(TransformerEncoder):
    """Node classifier: one stem of a ``TransformerEncoder``, its backbone/readout and a linear head.

    Args:
        stems, backbone, num_hops, device: forwarded to ``TransformerEncoder``.
        num_features: input feature width; selects the stem whose ``num_node_f`` matches.
        num_classes: output width of the linear ``predictor``.
        state_dict: optional encoder weights, loaded while ``self.stems`` is still the full
            ``ModuleList``; sets ``self.pretrained``.
    """
    save_freq = 5

    def __init__(self, stems, backbone, num_hops, num_features, num_classes, device, state_dict=None):
        super().__init__(stems, backbone, num_hops=num_hops, device=device)
        self.training_type = None
        self.lr_scheduler = None
        self.optimizer = None
        self.pretrained = False
        if state_dict is not None:
            self.load_state_dict(state_dict)
            self.pretrained = True
        # keep only the stem matching the input width
        stem_id = [stem['num_node_f'] for stem in self.configs['stems']].index(num_features)
        self.stems = self.stems[stem_id]
        self.predictor = torch.nn.Linear(int(self.hidden_dim / 2), num_classes)

    def forward(self, x):
        """Return class logits per node from its token sequence ``x[:, token, feature]``."""
        # stem
        x = self.stems(x)

        # transformer layers
        for backbone_layer in self.backbone:
            x = backbone_layer(x)

        # layer normalisation
        x = self.final_ln(x)

        # attention-based readout: node token (index 0) attends over the remaining tokens
        target = x[:, 0, :].unsqueeze(1).repeat(1, self.sequence_length - 1, 1)
        split_tensor = torch.split(x, [1, self.sequence_length - 1], dim=1)
        node_tensor = split_tensor[0]
        neighbor_tensor = split_tensor[1]
        layer_atten = self.attn_layer(torch.cat((target, neighbor_tensor), dim=2))
        layer_atten = F.softmax(layer_atten, dim=1)
        neighbor_tensor = neighbor_tensor * layer_atten
        neighbor_tensor = torch.sum(neighbor_tensor, dim=1, keepdim=True)
        output = (node_tensor + neighbor_tensor).squeeze(1)
        output = self.out_proj(output)
        output = self.predictor(output)
        return output

    def train_model(self, dataset, loss, optimizer, num_epochs, freeze_encoder, save_path=None, verbose=50,
                    lr_adapt=True):
        """Full-batch supervised training, then evaluation on the test mask.

        Args:
            dataset: object whose ``data`` has ``x``, ``y`` and boolean ``train_mask`` /
                ``val_mask`` / ``test_mask``.
            loss: callable ``loss(preds, targets)`` returning a scalar tensor.
            optimizer: optimizer over this model's parameters.
            num_epochs: number of full-batch epochs.
            freeze_encoder: if true, stem and backbone parameters stop requiring grad.
            save_path: if given, the model is saved whenever the validation loss improves
                (checked every ``save_freq`` epochs and on the last epoch).
            verbose: print progress every ``verbose`` epochs.
            lr_adapt: attach a ``ReduceLROnPlateau`` scheduler driven by the training loss.

        Returns:
            ``(test_loss, test_accuracy)`` from ``evaluate_model``.
        """
        self.optimizer = optimizer
        if freeze_encoder:
            self.freeze_encoder()
        self.train()
        self.to(self.device_)
        dataset.data.to(self.device_)
        if lr_adapt:
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)
        best_loss = float('inf')  # any finite validation loss counts as an improvement
        print('starting training of classifier')
        start_time = time.time()
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()

            # forward
            train_mask = dataset.data.train_mask
            preds = self.forward(x=dataset.data.x)
            train_loss = loss(preds[train_mask], dataset.data.y[train_mask])

            # backward and update
            train_loss.backward()
            self.optimizer.step()

            # save
            if save_path is not None and ((epoch + 1) % self.save_freq == 0 or (epoch + 1) == num_epochs):
                # re-uses this epoch's forward pass (computed before the optimizer step)
                val_mask = dataset.data.val_mask
                with torch.no_grad():
                    val_loss = loss(preds[val_mask], dataset.data.y[val_mask]).item()
                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_model(epoch, save_path)

            # print
            if (epoch + 1) % verbose == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; classification_loss: ' + str(round(train_loss.item(), 4))
                )

            # lr schedule
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(train_loss)

        # print
        print('training completed. elapsed time: ', str(time.time() - start_time))

        eval_loss, eval_acc = self.evaluate_model(dataset, loss)
        print('evaluation loss: ', str(round(eval_loss, 4)))
        print('evaluation accuracy: ', str(round(eval_acc, 4)))

        return eval_loss, eval_acc

    def freeze_encoder(self):
        """Stop gradient updates for the stem and transformer blocks.

        The readout (``attn_layer``, ``out_proj``, ``final_ln``) and the predictor stay trainable.
        """
        for p in self.stems.parameters():
            p.requires_grad = False
        for p in self.backbone.parameters():
            p.requires_grad = False

    def save_model(self, epoch, save_path):
        """Write model/optimizer state, configs and epoch to ``save_path`` with ``torch.save``."""
        save_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer': {
                'type': 'adam',
                'state_dict': self.optimizer.state_dict(),
            },
            'configs': self.configs,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def evaluate_model(self, dataset, loss):
        """Return ``(loss, accuracy)`` on ``dataset.data.test_mask``; leaves the model in eval mode."""
        self.eval()
        mask = dataset.data.test_mask
        with torch.no_grad():
            preds = self.forward(x=dataset.data.x)
            eval_loss = loss(preds[mask], dataset.data.y[mask])
        y_hat = preds.argmax(dim=1)
        test_correct = y_hat[mask] == dataset.data.y[mask]
        test_acc = int(test_correct.sum()) / int(mask.sum())
        return eval_loss.item(), test_acc


class TransformerLinkPredictor(TransformerEncoder):
    """Link predictor: one stem of a ``TransformerEncoder``, its backbone/readout and a scalar head.

    ``forward`` takes a ``(x, edge_label_index)`` pair and returns one logit per candidate
    edge: the product of the two endpoint outputs of ``predictor``
    (a ``Linear(hidden_dim / 2, 1)``).

    Args:
        stems, backbone, num_hops, device: forwarded to ``TransformerEncoder``.
        num_features: input feature width; selects the stem whose ``num_node_f`` matches.
        state_dict: optional encoder weights, loaded while ``self.stems`` is still the full
            ``ModuleList``; sets ``self.pretrained``.
    """
    save_freq = 5

    def __init__(self, stems, backbone, num_hops, num_features, device, state_dict=None):
        super().__init__(stems, backbone, num_hops=num_hops, device=device)
        self.training_type = None
        self.lr_scheduler = None
        self.optimizer = None
        self.pretrained = False
        if state_dict is not None:
            self.load_state_dict(state_dict)
            self.pretrained = True
        # keep only the stem matching the input width
        stem_id = [stem['num_node_f'] for stem in self.configs['stems']].index(num_features)
        self.stems = self.stems[stem_id]
        self.predictor = torch.nn.Linear(int(self.hidden_dim / 2), 1)

    def forward(self, x):
        """Return one logit per column of ``edge_label_index`` for input ``(x, edge_label_index)``."""
        x, edge_label_index = x
        x = self.stems(x)

        # transformer layers
        for backbone_layer in self.backbone:
            x = backbone_layer(x)

        # layer normalisation
        x = self.final_ln(x)

        # attention-based readout: node token (index 0) attends over the remaining tokens
        target = x[:, 0, :].unsqueeze(1).repeat(1, self.sequence_length - 1, 1)
        split_tensor = torch.split(x, [1, self.sequence_length - 1], dim=1)
        node_tensor = split_tensor[0]
        neighbor_tensor = split_tensor[1]
        layer_atten = self.attn_layer(torch.cat((target, neighbor_tensor), dim=2))
        layer_atten = F.softmax(layer_atten, dim=1)
        neighbor_tensor = neighbor_tensor * layer_atten
        neighbor_tensor = torch.sum(neighbor_tensor, dim=1, keepdim=True)
        output = (node_tensor + neighbor_tensor).squeeze(1)
        output = self.out_proj(output)
        # NOTE(review): predictor outputs width 1, so the "dot product" below is a product
        # of two scalars per edge — confirm this is the intended scoring.
        output = self.predictor(output)

        # link prediction
        source_embeddings = output[edge_label_index[0]]
        target_embeddings = output[edge_label_index[1]]
        logits = (source_embeddings * target_embeddings).sum(dim=-1)
        return logits

    def train_model(self, dataset, loss, optimizer, num_epochs, freeze_encoder, save_path=None, verbose=50,
                    lr_adapt=True):
        """Mini-batch training over ``dataset.train_loader``, then a final evaluation.

        Args:
            dataset: object with ``data`` (moved to ``device_``), ``train_loader`` and
                ``val_loader`` yielding batches with ``x``, ``edge_label_index`` and ``y``.
            loss: callable ``loss(logits, y)`` returning a scalar tensor.
            optimizer: optimizer over this model's parameters.
            num_epochs: number of passes over ``train_loader``.
            freeze_encoder: if true, stem and backbone parameters stop requiring grad.
            save_path: if given, the validation loss is computed every ``save_freq`` epochs
                and the model is saved when it improves.
            verbose: print progress every ``verbose`` epochs.
            lr_adapt: attach a ``ReduceLROnPlateau`` scheduler driven by the mean training loss.

        Returns:
            ``(loss, accuracy)`` from ``evaluate_model``.
        """
        self.optimizer = optimizer
        if freeze_encoder:
            self.freeze_encoder()
        self.train()
        self.to(self.device_)
        dataset.data.to(self.device_)
        if lr_adapt:
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)
        best_loss = float('inf')
        print('starting training of classifier')
        start_time = time.time()
        for epoch in range(num_epochs):
            total_loss, num_batches = 0.0, 0
            for data_batch in dataset.train_loader:
                # assumes PyG-style batches with a .to() method; no-op if already on device
                data_batch = data_batch.to(self.device_)

                # zero grad per batch: the optimizer steps once per batch
                self.optimizer.zero_grad()
                preds = self.forward(x=(data_batch.x, data_batch.edge_label_index))
                train_loss = loss(preds, data_batch.y)

                # backward and update
                train_loss.backward()
                self.optimizer.step()

                total_loss += train_loss.item()
                num_batches += 1
            epoch_loss = total_loss / max(num_batches, 1)

            # save
            if save_path is not None and (epoch + 1) % self.save_freq == 0:
                val_loss, _ = self._evaluate_loader(dataset.val_loader, loss)
                self.train()  # evaluation switched the model to eval mode
                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_model(epoch, save_path)

            # print
            if (epoch + 1) % verbose == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; classification_loss: ' + str(round(epoch_loss, 4))
                )

            # lr schedule
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(epoch_loss)

        # print
        print('training completed. elapsed time: ', str(time.time() - start_time))

        eval_loss, eval_acc = self.evaluate_model(dataset, loss)
        print('evaluation loss: ', str(round(eval_loss, 4)))
        print('evaluation accuracy: ', str(round(eval_acc, 4)))

        return eval_loss, eval_acc

    def _evaluate_loader(self, loader, loss):
        """Return ``(mean per-batch loss, accuracy)`` over ``loader``; leaves the model in eval mode.

        Accuracy thresholds logits at 0 and compares with ``y > 0.5``.
        NOTE(review): this assumes binary 0/1 edge labels and a logits-based loss
        (e.g. BCEWithLogitsLoss) — confirm against the datasets in use.
        """
        self.eval()
        total_loss, num_batches, correct, count = 0.0, 0, 0, 0
        with torch.no_grad():
            for data_batch in loader:
                data_batch = data_batch.to(self.device_)
                preds = self.forward(x=(data_batch.x, data_batch.edge_label_index))
                total_loss += loss(preds, data_batch.y).item()
                num_batches += 1
                correct += int(((preds > 0) == (data_batch.y > 0.5)).sum())
                count += data_batch.y.numel()
        return total_loss / max(num_batches, 1), correct / max(count, 1)

    def evaluate_model(self, dataset, loss):
        """Return ``(loss, accuracy)`` on ``dataset.test_loader``, or ``dataset.val_loader`` if absent.

        NOTE(review): ``test_loader`` is assumed; only ``train_loader`` / ``val_loader`` are
        referenced elsewhere in this class.
        """
        loader = getattr(dataset, 'test_loader', None)
        if loader is None:
            loader = dataset.val_loader
        return self._evaluate_loader(loader, loss)

    def freeze_encoder(self):
        """Stop gradient updates for the stem and transformer blocks."""
        for p in self.stems.parameters():
            p.requires_grad = False
        for p in self.backbone.parameters():
            p.requires_grad = False

    def save_model(self, epoch, save_path):
        """Write model/optimizer state, configs and epoch to ``save_path`` with ``torch.save``."""
        save_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer': {
                'type': 'adam',
                'state_dict': self.optimizer.state_dict(),
            },
            'configs': self.configs,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)


class SSLTransformer(torch.nn.Module):
    """Self-supervised pre-training wrapper around a graph transformer encoder.

    The encoder embeds the node features of every dataset in ``data.datasets``; each
    instantiated SSL task object turns those embeddings into a loss. Encoder
    parameters are updated by ``self.optimizer``; task heads are updated by the
    separate Adam optimizers in ``self.head_optimizers``.
    """
    base_lr = 1e-3
    # pretrain() only considers checkpointing every ``save_freq`` epochs
    save_freq = 20
    # per-task loss multipliers used by ``pretrain(weighted=True)``; keys match the
    # ``name`` attribute of the SSL task objects
    task_weights = {
        'graph partitioning': 0.9,
        'deep graph infomax': 0.65,
        'pair-wise attribute similarity': 0.65,
        'pair-wise distance': 0.1,
        'clustering': 0.1,
    }

    def __init__(self, encoder, data, ssl_tasks):
        """Move the encoder to its device and build the SSL tasks, optimizers and schedulers.

        Args:
            encoder: encoder module exposing ``device_``, ``out_proj``, ``configs`` and
                ``sequence_length``.
            data: container exposing ``datasets``; each dataset provides ``name`` and
                ``data.x``.
            ssl_tasks: task keys; recognised values are 'clustering', 'pairsim', 'dgi',
                'partition' and 'pairdis' (anything else is ignored).
        """
        super().__init__()
        # model
        self.encoder = encoder
        self.encoder.to(self.encoder.device_)

        # data
        self.data = data

        # ssl tasks
        self.ssl_tasks = ssl_tasks
        self.ssl_objects = self.instantiate_ssl()

        # optimization. The task objects live in a plain list (not registered as
        # submodules), so their prediction heads get optimizers of their own.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.base_lr)
        self.head_optimizers = [torch.optim.Adam(ssl_object.predictor.parameters(), lr=self.base_lr)
                                for ssl_object in self.ssl_objects if ssl_object.name != 'graph partitioning']
        if 'partition' in self.ssl_tasks:
            # the partition task exposes an iterable of predictors; one optimizer each
            partition_task = [ssl_object for ssl_object in self.ssl_objects if ssl_object.name == 'graph partitioning']
            partition_task = partition_task[0]
            partition_optimizers = [torch.optim.Adam(predictor.parameters(), lr=self.base_lr)
                                    for predictor in partition_task.predictor]
            self.head_optimizers = [*self.head_optimizers, *partition_optimizers]
        self.plateau_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=50,
                                                                               factor=1 / 3., verbose=True)
        self.lin_lr_scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer)
        self.lr_scheduler = [self.plateau_lr_scheduler, self.lin_lr_scheduler]

    def instantiate_ssl(self):
        """Create one task object per recognised key in ``self.ssl_tasks``.

        Tasks are created in a fixed order (clustering, pairsim, dgi, partition,
        pairdis) regardless of the order of ``self.ssl_tasks``.

        Returns:
            list of SSL task objects.
        """
        ssl_objects = []
        embedding_size = self.encoder.out_proj.out_features
        if 'clustering' in self.ssl_tasks:
            ssl_objects.append(
                clustering.ClusteringTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'pairsim' in self.ssl_tasks:
            ssl_objects.append(
                pairsim.PairwiseAttrSimTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'dgi' in self.ssl_tasks:
            ssl_objects.append(
                dgi.DGITask(data=self.data, encoder=self.encoder, embedding_size=embedding_size,
                            device=self.encoder.device_)
            )
        if 'partition' in self.ssl_tasks:
            ssl_objects.append(
                partition.PartitionTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        if 'pairdis' in self.ssl_tasks:
            ssl_objects.append(
                pairdis.PairwiseDistanceTask(data=self.data, embedding_size=embedding_size, device=self.encoder.device_)
            )
        return ssl_objects

    def _write_config(self, folder, num_epochs, weighted):
        """Write the encoder config and run settings to ``<folder>/config.txt``."""
        with open(os.path.join(folder, 'config.txt'), 'w') as f:
            f.write(json.dumps(self.encoder.configs))
            dataset_names = [dataset.name for dataset in self.data.datasets]
            f.write("\nepochs: " + str(num_epochs))
            f.write("\ndatasets: " + "; ".join(dataset_names))
            f.write("\nmodel type: node aggregation graph transformer")
            f.write("\n" + json.dumps({'weighted': weighted}))

    def _dataset_loss(self, dataset, weighted):
        """Embed ``dataset`` with the encoder and return the summed loss of all SSL tasks."""
        x = dataset.data.x.to(self.encoder.device_)
        x = self.encoder(x=x)
        ssl_loss = 0
        for ssl_object in self.ssl_objects:
            task_loss = ssl_object.get_loss(x, dataset.name)
            if weighted:
                task_loss = self.task_weights[ssl_object.name] * task_loss
            ssl_loss = ssl_loss + task_loss
        return ssl_loss

    def pretrain(self, num_epochs=1000, weighted=False, lr_adapt=True, verbose_freq=50, batched=False, save_path=None):
        """Run SSL pre-training.

        Each epoch runs one backward pass per dataset (gradients accumulate) and then a
        single step of the encoder optimizer and every head optimizer. The epoch loss
        used for checkpointing, logging and LR scheduling is the sum over all datasets.

        Args:
            num_epochs: number of epochs to train.
            weighted: scale each task loss by ``task_weights[task.name]``.
            lr_adapt: if False, the learning-rate schedulers are not stepped.
            verbose_freq: print the loss every ``verbose_freq`` epochs (and on the first
                and last epoch).
            batched: currently unused; kept for interface compatibility.
            save_path: ``None`` to disable saving, otherwise a dict with ``'folder'``
                (directory for ``config.txt``) and ``'file'`` (checkpoint path).

        Raises:
            ValueError: if no SSL task was instantiated or there are no datasets.
        """
        if not self.ssl_objects:
            raise ValueError('no ssl tasks instantiated from ssl_tasks: ' + str(self.ssl_tasks))
        datasets = list(self.data.datasets)
        if not datasets:
            raise ValueError('no datasets available for pre-training')

        if not lr_adapt:
            # kept for backward compatibility with code that inspects ``lr_scheduler``;
            # the schedule itself is gated on ``lr_adapt`` below
            self.lr_scheduler = None

        if save_path is not None:
            self._write_config(save_path['folder'], num_epochs, weighted)

        best_loss = float('inf')
        self.train()
        print('starting pre-training')
        start_time = time.time()
        printed = False
        for epoch in range(num_epochs):
            # zero grad
            self.optimizer.zero_grad()
            for head_optimizer in self.head_optimizers:
                head_optimizer.zero_grad()

            # forward + gradients, accumulated over all datasets
            epoch_loss = 0.0
            for dataset in datasets:
                ssl_loss = self._dataset_loss(dataset, weighted)
                ssl_loss.backward()
                epoch_loss += ssl_loss.item()

            # update
            self.optimizer.step()
            for head_optimizer in self.head_optimizers:
                head_optimizer.step()

            # save and overwrite model
            if (epoch + 1) % self.save_freq == 0 and epoch_loss < best_loss:
                best_loss = epoch_loss
                if save_path is not None:
                    self.save_model(epoch, save_path['file'])
                elif not printed:
                    print('no save_path provided. skipping saving')
                    printed = True

            # print
            if (epoch + 1) % verbose_freq == 0 or epoch == 0 or epoch == num_epochs - 1:
                print(
                    '[' + datetime.datetime.now().strftime(format='%d.%m.%y %H:%M:%S') + ']' +
                    ' epoch: ' + str(epoch + 1) +
                    '; ssl_loss: ' + str(round(epoch_loss, 4))
                )

            # lr schedule
            if lr_adapt and self.plateau_lr_scheduler is not None:
                self.adjust_lr(epoch_loss)

        # print
        print('pre-training completed. elapsed time: ', str(time.time() - start_time))

    def save_model(self, epoch, save_path):
        """Save the encoder weights, optimizer state and run metadata to ``save_path``."""
        save_dict = {
            'data': [dataset.name for dataset in self.data.datasets],
            'model_state_dict': self.encoder.state_dict(),
            'optimizer': {
                'type': 'adam',
                'base_lr': self.base_lr,
                'state_dicts': self.optimizer.state_dict(),
            },
            'configs': self.encoder.configs,
            'num_hops': self.encoder.sequence_length - 1,
            'ssl_tasks': self.ssl_tasks,
            'epoch': epoch,
        }
        torch.save(save_dict, save_path)

    def adjust_lr(self, ssl_loss):
        """Step the LinearLR schedule, then ReduceLROnPlateau with ``ssl_loss`` as metric."""
        self.lin_lr_scheduler.step()
        self.plateau_lr_scheduler.step(ssl_loss)